import numpy as np

"""
激活函数 Relu
当x＞0时，输出x;
当x<=0时，输出0

"""


class Relu:
    """ReLU activation layer with forward and backward passes.

    ``forward`` sets every element with ``x <= 0`` to 0 and leaves the other
    elements unchanged. The boolean mask of zeroed positions is cached so that
    ``backward`` can block the incoming gradient at exactly those positions.
    """

    def __init__(self):
        # Boolean array: True where the most recent forward() input was <= 0.
        # Stays None until forward() has been called at least once.
        self.mask = None

    def forward(self, x):
        """Apply ReLU element-wise.

        Args:
            x: Array-like input (ndarray, list, or scalar). The input is not
                modified; a new array is returned.

        Returns:
            A new ndarray with the same shape as ``x``, where entries with
            ``x <= 0`` are 0 and all others equal ``x``.
        """
        # asarray lets lists/scalars work; it does not copy an existing ndarray.
        x = np.asarray(x)
        self.mask = x <= 0
        out = x.copy()
        out[self.mask] = 0
        return out

    def backward(self, dout):
        """Propagate the upstream gradient through the ReLU.

        Args:
            dout: Upstream gradient, with the same shape as the last
                ``forward`` input. It is NOT modified.

        Returns:
            A new array equal to ``dout`` with entries zeroed wherever the
            last ``forward`` input was <= 0.

        Raises:
            RuntimeError: If called before ``forward``. Without this check,
                ``dout[None] = 0`` would silently zero the whole array.
        """
        if self.mask is None:
            raise RuntimeError("Relu.backward() called before forward()")
        # Copy so the caller's gradient buffer is not mutated in place.
        dx = np.array(dout, copy=True)
        dx[self.mask] = 0
        return dx


if __name__ == '__main__':
    # Quick demo: negative entries are clamped to 0 by the forward pass.
    relu = Relu()
    sample = np.array([1, 2, 3, 4, -1])
    print(relu.forward(sample))
